{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

/* ----------------------------------------------------------------------------
 * SymForce - Copyright 2022, Skydio, Inc.
 * This source code is under the Apache 2.0 license found in the LICENSE file.
 * ---------------------------------------------------------------------------- */

/**
 * Tests for C++ camera types. Mostly checking all the templates compile since
 * the math is tested comprehensively in symbolic form.
 */

#include <cstddef>
#include <random>
#include <stdexcept>

#include <Eigen/Dense>
#include <fmt/ostream.h>
#include <spdlog/spdlog.h>

{% for cls in all_types %}
#include <sym/{{ camelcase_to_snakecase(cls.__name__) }}.h>
{% endfor %}
#include <sym/camera.h>
#include <sym/posed_camera.h>
#include <sym/pose3.h>
#include <sym/rot3.h>
#include <sym/ops/group_ops.h>
#include <sym/ops/lie_group_ops.h>
#include <sym/util/epsilon.h>

#include <catch2/catch_template_test_macros.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/generators/catch_generators_range.hpp>

template <typename T>
T CalFromData(std::initializer_list<typename T::Scalar> data) {
  Eigen::Matrix<typename T::Scalar, sym::StorageOps<T>::StorageDim(), 1> data_vec;
  size_t i = 0;
  for (const auto& x : data) {
    data_vec[i] = x;
    i++;
  }
  return T(data_vec);
}

/**
 * Provider of sample calibrations for a camera calibration type T.
 *
 * Only the declaration of Get() lives here; the per-type specializations supply
 * the actual list of calibrations.
 */
template <typename T>
struct CamCals {
  static auto Get() -> std::vector<T>;
};

template <typename Scalar>
struct CamCals<sym::LinearCameraCal<Scalar>> {
  using T = sym::LinearCameraCal<Scalar>;
  static std::vector<T> Get() {
    return {
        CalFromData<T>({90, 90, 60, 80}),
        CalFromData<T>({380, 380, 320, 240}),
        CalFromData<T>({500, 500, 1000, 800}),
    };
  }
};

template <typename Scalar>
struct CamCals<sym::DoubleSphereCameraCal<Scalar>> {
  using T = sym::DoubleSphereCameraCal<Scalar>;
  static std::vector<T> Get() {
    return {
        CalFromData<T>({233, 233, 388, 388, 1.0, 0.0}),
        CalFromData<T>({833, 834, 388, 388, 2.559, -1.328}),
        CalFromData<T>({313, 313, 638, 514, -0.18, 0.59}),
    };
  }
};

template <typename Scalar>
struct CamCals<sym::SphericalCameraCal<Scalar>> {
  using T = sym::SphericalCameraCal<Scalar>;
  static std::vector<T> Get() {
    return {
        CalFromData<T>({234, 234, 388, 391, M_PI, 0, 0, 0, 0}),
        CalFromData<T>({234, 234, 388, 391, M_PI, 0.035, -0.025, 0.0070, -0.0015}),
        CalFromData<T>({250, 250, 700, 500, M_PI, -0.01, -0.016, 0.0020, 0.0030}),
    };
  }
};

template <typename Scalar>
struct CamCals<sym::PolynomialCameraCal<Scalar>> {
  using T = sym::PolynomialCameraCal<Scalar>;
  static std::vector<T> Get() {
    return {
        CalFromData<T>({234, 234, 388, 391, M_PI/3, 0, 0, 0}),
        CalFromData<T>({234, 234, 388, 391, M_PI/3, 0.035, -0.025, 0.0070}),
        CalFromData<T>({250, 250, 700, 500, M_PI/3, -0.01, -0.016, 0.0020}),
    };
  }
};

template <typename Scalar>
struct CamCals<sym::EquirectangularCameraCal<Scalar>> {
  using T = sym::EquirectangularCameraCal<Scalar>;
  static std::vector<T> Get() {
    return {
        CalFromData<T>({90, 90, 60, 80}),
        CalFromData<T>({380, 380, 320, 240}),
        CalFromData<T>({500, 500, 1000, 800}),
    };
  }
};

template <typename Scalar>
struct CamCals<sym::ATANCameraCal<Scalar>> {
  using T = sym::ATANCameraCal<Scalar>;
  static std::vector<T> Get() {
    return {
        CalFromData<T>({90, 90, 60, 80, 0.68}),
        CalFromData<T>({380, 380, 320, 240, 0.35}),
        CalFromData<T>({500, 500, 1000, 800, 0.21}),
    };
  }
};

TEMPLATE_TEST_CASE("Test storage ops", "[cam_package]", {{ ", ".join(cpp_cam_types) }}) {
  using T = TestType;
  using Scalar = typename T::Scalar;
  using Storage = sym::StorageOps<T>;

  const T& cam_cal = GENERATE(from_range(CamCals<T>::Get()));

  spdlog::debug("*** Testing StorageOps: {} ***", cam_cal);

  // The data vector must be a single column with StorageDim() rows
  constexpr int32_t storage_dim = Storage::StorageDim();
  const auto& data = cam_cal.Data();
  CHECK(data.rows() == storage_dim);
  CHECK(data.cols() == 1);

  // Writing to a raw buffer must reproduce the data vector entry by entry
  std::array<Scalar, storage_dim> buffer;
  cam_cal.ToStorage(buffer.data());
  for (int32_t idx = 0; idx < storage_dim; ++idx) {
    CHECK(buffer[idx] == data[idx]);
  }

  // Reading the untouched buffer back gives identical data ...
  const T restored = Storage::FromStorage(buffer.data());
  CHECK(data == restored.Data());

  // ... while overwriting the first entry gives different data
  buffer[0] = 2.1;
  const T perturbed = Storage::FromStorage(buffer.data());
  CHECK(data != perturbed.Data());
}

TEMPLATE_TEST_CASE("Test group ops", "[cam_package]", {{ ", ".join(cpp_cam_types) }}) {
  using T = TestType;
  using Group = sym::GroupOps<T>;
  using SelfJacobian = typename Group::SelfJacobian;

  // Tolerance for every approximate comparison against the identity element
  constexpr double kTolerance = 1e-9;

  const T identity = Group::Identity();
  spdlog::debug("*** Testing GroupOps: {} ***", identity);

  // The self-jacobian is square, with one row and column per tangent dimension
  const auto tangent_dim = sym::LieGroupOps<T>::TangentDim();
  CHECK(tangent_dim == SelfJacobian::RowsAtCompileTime);
  CHECK(tangent_dim == SelfJacobian::ColsAtCompileTime);

  // Composing, inverting, or differencing identities yields the identity again
  CHECK(identity.IsApprox(Group::Compose(identity, identity), kTolerance));
  CHECK(identity.IsApprox(Group::Inverse(identity), kTolerance));
  CHECK(identity.IsApprox(Group::Between(identity, identity), kTolerance));

  // Same checks through the jacobian-producing variants; the jacobian values
  // themselves are not inspected here
  SelfJacobian jacobian1 = SelfJacobian::Zero();
  SelfJacobian jacobian2 = SelfJacobian::Zero();
  CHECK(identity.IsApprox(
      Group::ComposeWithJacobians(identity, identity, &jacobian1, &jacobian2), kTolerance));
  CHECK(identity.IsApprox(Group::InverseWithJacobian(identity, &jacobian1), kTolerance));
  CHECK(identity.IsApprox(
      Group::BetweenWithJacobians(identity, identity, &jacobian1, &jacobian2), kTolerance));
}

TEMPLATE_TEST_CASE("Test Lie group ops", "[cam_package]", {{ ", ".join(cpp_cam_types) }}) {
  using T = TestType;
  using Scalar = typename T::Scalar;
  using TangentVec = typename sym::LieGroupOps<T>::TangentVec;

  // Use the epsilon matching the calibration's own scalar type (as the PosedCamera
  // test in this file does) rather than the fixed sym::kDefaultEpsilond, so that
  // float instantiations are given an epsilon and tolerance of their own type.
  const Scalar epsilon = sym::kDefaultEpsilon<Scalar>;

  const T identity = sym::GroupOps<T>::Identity();
  spdlog::debug("*** Testing LieGroupOps: {} ***", identity);

  // The tangent vector has one row per tangent dimension
  CHECK(sym::LieGroupOps<T>::TangentDim() == TangentVec::RowsAtCompileTime);

  // A zero tangent vector maps to the identity, and the identity maps to (approximately) zero
  CHECK(identity.IsApprox(sym::LieGroupOps<T>::FromTangent(TangentVec::Zero(), epsilon), epsilon));
  CHECK(sym::LieGroupOps<T>::ToTangent(identity, epsilon).norm() < epsilon);

  // Retracting by zero stays at the identity; the identity has zero local coordinates
  // relative to itself
  CHECK(identity.IsApprox(sym::LieGroupOps<T>::Retract(identity, TangentVec::Zero(), epsilon), epsilon));
  CHECK(sym::LieGroupOps<T>::LocalCoordinates(identity, identity, epsilon).norm() < epsilon);
}

TEMPLATE_TEST_CASE("Test project and deproject", "[cam_package]", {{ ", ".join(fully_implemented_cpp_cam_types) }}) {
  using T = TestType;
  using Scalar = typename T::Scalar;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;
  using Vec3 = Eigen::Matrix<Scalar, 3, 1>;

  const T& cam_cal = GENERATE(from_range(CamCals<T>::Get()));

  // Guards against degenerate numerical cases (e.g. division by zero)
  const Scalar epsilon = 1e-6;
  // Threshold for the approximate pixel comparison below
  const Scalar tolerance = 10.0 * epsilon;

  spdlog::debug("*** Testing projection model: {} ***", cam_cal);

  std::mt19937 gen(42);
  // Sample pixels from [0, 2 * Data()[2]) x [0, 2 * Data()[3]), i.e. around the principal point
  std::uniform_real_distribution<Scalar> pixel_x_dist(0.0, 2.0 * cam_cal.Data()[2]);
  std::uniform_real_distribution<Scalar> pixel_y_dist(0.0, 2.0 * cam_cal.Data()[3]);

  int remaining_samples = 10;
  while (remaining_samples-- > 0) {
    Vec2 pixel;
    pixel << pixel_x_dist(gen), pixel_y_dist(gen);

    // Pixel -> camera ray -> pixel round trip
    Scalar is_valid_camera_ray;
    Scalar is_valid_pixel;
    const Vec3 camera_ray = cam_cal.CameraRayFromPixel(pixel, epsilon, &is_valid_camera_ray);
    const Vec2 pixel_reprojected =
        cam_cal.PixelFromCameraPoint(camera_ray, epsilon, &is_valid_pixel);

    // The round trip is only compared when both directions reported validity
    if (is_valid_camera_ray != 1 || is_valid_pixel != 1) {
      continue;
    }
    CHECK(pixel.isApprox(pixel_reprojected, tolerance));
  }
}

TEMPLATE_TEST_CASE("Test Camera class", "[cam_package]", {{ ", ".join(fully_implemented_cpp_cam_types) }}) {
  using T = TestType;
  using Scalar = typename T::Scalar;
  using Pixel = Eigen::Matrix<Scalar, 2, 1>;
  using Cam = sym::Camera<T>;

  const T& cam_cal = GENERATE(from_range(CamCals<T>::Get()));

  // Guards against degenerate numerical cases (e.g. division by zero)
  const Scalar epsilon = 1e-6;

  spdlog::debug("*** Testing Camera class with calibration: {} ***", cam_cal);

  // Place the principal point at the image center: the image is twice its coordinates
  Eigen::Matrix<int, 2, 1> image_size;
  image_size << static_cast<int>(2.0 * cam_cal.Data()[2]),
      static_cast<int>(2.0 * cam_cal.Data()[3]);

  const Cam cam(cam_cal, image_size);

  // The camera hands back what it was constructed with
  CHECK(cam.Calibration() == cam_cal);
  CHECK(cam.ImageSize() == image_size);

  // Focal length and principal point agree with the calibration's accessors and
  // with the first four entries of its data vector
  CHECK(cam.FocalLength() == cam_cal.FocalLength());
  CHECK(cam.PrincipalPoint() == cam_cal.PrincipalPoint());
  const auto focal_length_expected = cam_cal.Data().template head<2>();
  CHECK(cam.FocalLength() == focal_length_expected);
  const auto principal_point_expected = cam_cal.Data().template segment<2>(2);
  CHECK(cam.PrincipalPoint() == principal_point_expected);

  Scalar is_valid;

  // A pixel outside the image is rejected by every validity query
  Pixel invalid_pixel;
  invalid_pixel << -1, -1;
  cam.CameraRayFromPixel(invalid_pixel, epsilon, &is_valid);
  CHECK(is_valid == 0);
  CHECK(cam.MaybeCheckInView(invalid_pixel) == 0);
  CHECK(Cam::InView(invalid_pixel, image_size) == 0);

  // The pixel at the center of the image is accepted by every validity query
  Pixel valid_pixel;
  valid_pixel << image_size[0] / 2.0, image_size[1] / 2.0;
  const Eigen::Matrix<Scalar, 3, 1> valid_camera_point =
      cam.CameraRayFromPixel(valid_pixel, epsilon, &is_valid);
  CHECK(is_valid == 1);
  CHECK(cam.MaybeCheckInView(valid_pixel) == 1);
  CHECK(Cam::InView(valid_pixel, image_size) == 1);

  // Projecting the resulting camera point back into the camera must be valid too
  cam.PixelFromCameraPoint(valid_camera_point, epsilon, &is_valid);
  CHECK(is_valid == 1);
}

// Round-trip tests for sym::PosedCamera: pixel -> global point -> pixel with one
// camera, and pixel -> warped pixel -> re-warped pixel between two cameras.
// Runs once per calibration returned by CamCals<T>::Get().
TEMPLATE_TEST_CASE("Test PosedCamera class", "[cam_package]", {{ ", ".join(fully_implemented_cpp_cam_types) }}) {
  using T = TestType;
  const T& cam_cal = GENERATE(from_range(CamCals<T>::Get()));

  using Scalar = typename T::Scalar;
  // For preventing degenerate numerical cases (e.g. division by zero)
  const Scalar epsilon = sym::kDefaultEpsilon<Scalar>;
  // For assessing approximate equality
  const Scalar tolerance = 50.0 * epsilon;

  spdlog::debug("*** Testing PosedCamera class with calibration: {} ***", cam_cal);

  // One generator with a fixed seed feeds every random draw below (poses, pixel,
  // range), so the order of the draws determines the values used.
  std::mt19937 gen(42);
  // Generate pixels around principal point
  std::uniform_real_distribution<Scalar> pixel_x_dist(0.0, 2.0 * cam_cal.Data()[2]);
  std::uniform_real_distribution<Scalar> pixel_y_dist(0.0, 2.0 * cam_cal.Data()[3]);
  // Range along the ray at which the global point is placed
  std::uniform_real_distribution<Scalar> range_dist(1.0, 5.0);
  for(int i = 0; i < 10; i++) {
    const sym::Pose3<Scalar> pose = sym::Random<sym::Pose3<Scalar>>(gen);
    const sym::PosedCamera<T> cam(pose, cam_cal);

    // The camera reports the pose it was constructed with
    CHECK(cam.Pose() == pose);

    Eigen::Matrix<Scalar, 2, 1> pixel;
    pixel << pixel_x_dist(gen), pixel_y_dist(gen);
    const Scalar range_to_point = range_dist(gen);

    // Pixel -> global point -> pixel with the same camera. The comparison is made
    // only when both calls reported validity.
    Scalar is_valid_global_point;
    Scalar is_valid_pixel;
    const Eigen::Matrix<Scalar, 3, 1> global_point =
        cam.GlobalPointFromPixel(pixel, range_to_point, epsilon, &is_valid_global_point);
    const Eigen::Matrix<Scalar, 2, 1> pixel_reprojected =
        cam.PixelFromGlobalPoint(global_point, epsilon, &is_valid_pixel);
    if (is_valid_global_point == 1 && is_valid_pixel == 1) {
      CHECK(pixel.isApprox(pixel_reprojected, tolerance));
    }

    // Warp the pixel into a second, independently drawn camera and back again.
    const sym::Pose3<Scalar> pose2 = sym::Random<sym::Pose3<Scalar>>(gen);
    const sym::PosedCamera<T> cam2(pose2, cam_cal);
    Scalar is_valid_warped_pixel;
    // epsilon in the denominator keeps the inverse finite
    const Scalar inverse_range = 1 / (range_to_point + epsilon);
    const Eigen::Matrix<Scalar, 2, 1> warped_pixel =
        cam.WarpPixel(pixel, inverse_range, cam2, epsilon, &is_valid_warped_pixel);
    // Range of the same global point as seen from the second camera: the norm of
    // the point after applying the inverse of cam2's pose
    const Scalar range2 = (cam2.Pose().Inverse() * global_point).norm();
    const Scalar inverse_range2 = 1 / (range2 + epsilon);
    Scalar is_valid_rewarped_pixel;
    const Eigen::Matrix<Scalar, 2, 1> rewarped_pixel =
        cam2.WarpPixel(warped_pixel, inverse_range2, cam, epsilon, &is_valid_rewarped_pixel);
    // The two-way warp is compared with sqrt(epsilon) rather than `tolerance`
    if (is_valid_warped_pixel == 1 && is_valid_rewarped_pixel == 1) {
      CHECK(pixel.isApprox(rewarped_pixel, std::sqrt(epsilon)));
    }
  }
}

TEMPLATE_TEST_CASE("Test ATANCameraCal named constructor", "[cam_package]", double, float) {
  using Scalar = TestType;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;

  // Distinct values per argument so a mixed-up argument order would be detected
  const Vec2 focal_length(100, 200);
  const Vec2 principal_point(300, 400);
  const Scalar omega = 0.5;

  const sym::ATANCameraCal<Scalar> cal(focal_length, principal_point, omega);

  // Each argument must be readable back through its accessor or storage index
  CHECK(cal.FocalLength() == focal_length);
  CHECK(cal.PrincipalPoint() == principal_point);
  CHECK(cal.Data()[4] == omega);
}

TEMPLATE_TEST_CASE("Test DoubleSphereCameraCal named constructor", "[cam_package]", double, float) {
  using Scalar = TestType;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;

  // Distinct values per argument so a mixed-up argument order would be detected
  const Vec2 focal_length(100, 200);
  const Vec2 principal_point(300, 400);
  const Scalar xi = 0.1;
  const Scalar alpha = 0.2;

  const sym::DoubleSphereCameraCal<Scalar> cal(focal_length, principal_point, xi, alpha);

  // Each argument must be readable back through its accessor or storage index
  CHECK(cal.FocalLength() == focal_length);
  CHECK(cal.PrincipalPoint() == principal_point);
  CHECK(cal.Data()[4] == xi);
  CHECK(cal.Data()[5] == alpha);
}

TEMPLATE_TEST_CASE("Test EquirectangularCameraCal named constructor", "[cam_package]", double, float) {
  using Scalar = TestType;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;

  // Distinct values per argument so a mixed-up argument order would be detected
  const Vec2 focal_length(100, 200);
  const Vec2 principal_point(300, 400);

  const sym::EquirectangularCameraCal<Scalar> cal(focal_length, principal_point);

  // Both arguments must be readable back through their accessors
  CHECK(cal.FocalLength() == focal_length);
  CHECK(cal.PrincipalPoint() == principal_point);
}

TEMPLATE_TEST_CASE("Test LinearCameraCal named constructor", "[cam_package]", double, float) {
  using Scalar = TestType;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;

  // Distinct values per argument so a mixed-up argument order would be detected
  const Vec2 focal_length(100, 200);
  const Vec2 principal_point(300, 400);

  const sym::LinearCameraCal<Scalar> cal(focal_length, principal_point);

  // Both arguments must be readable back through their accessors
  CHECK(cal.FocalLength() == focal_length);
  CHECK(cal.PrincipalPoint() == principal_point);
}

TEMPLATE_TEST_CASE("Test PolynomialCameraCal named constructor", "[cam_package]", double, float) {
  using Scalar = TestType;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;
  using Vec3 = Eigen::Matrix<Scalar, 3, 1>;

  // Distinct values per argument so a mixed-up argument order would be detected
  const Vec2 focal_length(100, 200);
  const Vec2 principal_point(300, 400);
  const Scalar critical_undistorted_radius = 0.5;
  const Vec3 distortion_coeffs(0.1, 0.2, 0.3);

  const sym::PolynomialCameraCal<Scalar> cal(focal_length, principal_point,
                                             critical_undistorted_radius, distortion_coeffs);

  // Each argument must be readable back through its accessor or storage position
  CHECK(cal.FocalLength() == focal_length);
  CHECK(cal.PrincipalPoint() == principal_point);
  CHECK(cal.Data()[4] == critical_undistorted_radius);
  CHECK(cal.Data().template tail<3>() == distortion_coeffs);
}

TEMPLATE_TEST_CASE("Test SphericalCameraCal named constructor", "[cam_package]", double, float) {
  using Scalar = TestType;
  using Vec2 = Eigen::Matrix<Scalar, 2, 1>;
  using Vec4 = Eigen::Matrix<Scalar, 4, 1>;

  // Distinct values per argument so a mixed-up argument order would be detected
  const Vec2 focal_length(100, 200);
  const Vec2 principal_point(300, 400);
  const Scalar critical_theta = 0.5;
  const Vec4 distortion_coeffs(0.1, 0.2, 0.3, 0.4);

  const sym::SphericalCameraCal<Scalar> cal(focal_length, principal_point, critical_theta,
                                            distortion_coeffs);

  // Each argument must be readable back through its accessor or storage position
  CHECK(cal.FocalLength() == focal_length);
  CHECK(cal.PrincipalPoint() == principal_point);
  CHECK(cal.Data()[4] == critical_theta);
  CHECK(cal.Data().template tail<4>() == distortion_coeffs);
}
